package gorabbitmq

import (
	"context"
	"encoding/json"
	"fmt"
	"github.com/go-kratos/kratos/v2/encoding"
	"github.com/go-kratos/kratos/v2/log"
	"github.com/go-kratos/kratos/v2/transport"
	"github.com/wagslane/go-rabbitmq"
	"reflect"
)

// Consumer describes one RabbitMQ consumer that ConsumerServer.Start registers
// on the server's shared connection.
type Consumer struct {
	// QueueName is passed as the queue argument to rabbitmq.NewConsumer.
	QueueName string
	// Handler is invoked for every delivery and returns the ack/nack action.
	Handler   rabbitmq.Handler
	// Options are this consumer's go-rabbitmq options. Start applies them after
	// the server-wide options, so they are applied last.
	Options   []func(options *rabbitmq.ConsumerOptions)
}

// Compile-time check that *ConsumerServer satisfies kratos' transport.Server.
var _ transport.Server = (*ConsumerServer)(nil)

// ConsumerServer is a kratos transport.Server. Start opens a single RabbitMQ
// connection and registers every configured Consumer on it; Stop closes that
// connection.
type ConsumerServer struct {
	// conn is the shared connection created by Start and closed by Stop.
	conn                  *rabbitmq.Conn
	// Address is the connection address handed to rabbitmq.NewConn.
	Address               string
	// Consumers are the consumers Start creates on conn.
	Consumers             []Consumer
	// globalConsumerOptions are applied to every consumer before its own Options.
	// NOTE(review): presumably populated by ServerOption values defined elsewhere — confirm.
	globalConsumerOptions []func(options *rabbitmq.ConsumerOptions)
}

// NewRabbitMQConsumerServer returns a ConsumerServer for the broker at address
// that will run the given consumers. Each ServerOption is then applied to the
// new server in the order given. No connection is opened here; Start does that.
func NewRabbitMQConsumerServer(address string, consumers []Consumer, options ...ServerOption) *ConsumerServer {
	srv := new(ConsumerServer)
	srv.Address = address
	srv.Consumers = consumers
	for i := range options {
		options[i](srv)
	}
	return srv
}

// NewConsumerWithDefaultOptions returns a Consumer for queue whose option list
// starts with this package's defaults, followed by the caller's options. The
// caller's options therefore run later and can override the defaults.
//
// The defaults are:
//   - the exchange name set to exchange;
//   - a binding on routingKey with BindingOptions.Declare enabled;
//   - a durable queue;
//   - global QoS (rabbitmq.WithConsumerOptionsQOSGlobal).
//
// The prefetch count and handler concurrency are not set here; pass options
// for them if needed.
func NewConsumerWithDefaultOptions(queue, exchange, routingKey string, handler rabbitmq.Handler, options ...func(consumerOptions *rabbitmq.ConsumerOptions)) Consumer {
	binding := rabbitmq.Binding{
		RoutingKey:     routingKey,
		BindingOptions: rabbitmq.BindingOptions{Declare: true},
	}
	defaults := []func(options *rabbitmq.ConsumerOptions){
		rabbitmq.WithConsumerOptionsExchangeName(exchange),
		rabbitmq.WithConsumerOptionsBinding(binding),
		rabbitmq.WithConsumerOptionsQueueDurable,
		// Per go-rabbitmq: QoS settings apply to all existing and future
		// consumers on all channels of the same connection.
		rabbitmq.WithConsumerOptionsQOSGlobal,
	}
	return Consumer{
		QueueName: queue,
		Handler:   handler,
		Options:   append(defaults, options...),
	}
}

// AppendConsumer builds a consumer with NewConsumer and returns consumers with
// it appended. As with the built-in append, the result may share its backing
// array with the consumers argument.
func AppendConsumer[T any](consumers []Consumer, exchange, routingKey, queue, serviceName string, handler func(ctx context.Context, req T) error, options ...func(*rabbitmq.ConsumerOptions)) []Consumer {
	c := NewConsumer(exchange, routingKey, queue, serviceName, handler, options...)
	consumers = append(consumers, c)
	return consumers
}

// NewConsumer builds a Consumer that decodes each message body into a T and
// passes it to handler.
//
// The queue is named "<routingKey>-<serviceName>-<queue>". It is configured
// through NewConsumerWithDefaultOptions, with options appended after those
// defaults.
//
// Decoding: the body is decoded with encoding/json first. If that fails and a
// kratos codec named "json" is registered, decoding is retried with it.
// T may be:
//   - a value type, decoded through &req;
//   - a pointer type, for which a fresh pointee is allocated and decoded into;
//   - an interface type.
//
// Acknowledgement:
//   - the body cannot be decoded: NackDiscard. Redelivering identical bytes
//     would fail again and loop forever. Attach a dead-letter exchange to the
//     queue if such messages must be kept.
//   - handler returns an error: NackRequeue, so the message is retried.
//   - otherwise: Ack.
func NewConsumer[T any](exchange, routingKey, queue, serviceName string, handler func(ctx context.Context, req T) error, options ...func(*rabbitmq.ConsumerOptions)) Consumer {
	queueName := fmt.Sprintf("%v-%v-%v", routingKey, serviceName, queue)
	return NewConsumerWithDefaultOptions(
		queueName,
		exchange,
		routingKey,
		func(d rabbitmq.Delivery) rabbitmq.Action {
			req, err := decodeDeliveryBody[T](d.Body)
			if err != nil {
				log.Errorw("msg", "Consumer:err decode message error", "queue", queueName, "err", err)
				return rabbitmq.NackDiscard
			}
			if err := handler(context.Background(), req); err != nil {
				log.Errorw("msg", "Consumer:err handler error", "queue", queueName, "err", err)
				return rabbitmq.NackRequeue
			}
			return rabbitmq.Ack
		},
		options...,
	)
}

// decodeDeliveryBody decodes body into a T. It tries encoding/json first and,
// if that fails, falls back to the kratos codec registered as "json" (when one
// is registered). On error the returned T is the zero value and must not be used.
func decodeDeliveryBody[T any](body []byte) (T, error) {
	req, err := unmarshalAs[T](body, json.Unmarshal)
	if err == nil {
		return req, nil
	}
	codec := encoding.GetCodec("json")
	if codec == nil {
		var zero T
		return zero, fmt.Errorf("decode message: %w", err)
	}
	// Decode into a fresh value so nothing from the failed attempt leaks through.
	req, codecErr := unmarshalAs[T](body, codec.Unmarshal)
	if codecErr != nil {
		var zero T
		return zero, fmt.Errorf("decode message: encoding/json: %w; kratos json codec: %v", err, codecErr)
	}
	return req, nil
}

// unmarshalAs allocates a T and runs unmarshal into it. For a pointer T, the
// pointee is allocated up front and passed directly, so it is filled in place
// and never left nil. For any other T, including interface types, &req is
// passed so the decoder can set the value.
func unmarshalAs[T any](body []byte, unmarshal func([]byte, any) error) (T, error) {
	var req T
	var target any = &req
	// TypeOf(&req).Elem() is T itself, even when T is an interface type
	// (TypeOf(req) would be nil for a nil interface).
	if rt := reflect.TypeOf(&req).Elem(); rt.Kind() == reflect.Pointer {
		// Convert keeps named pointer types (type P *X) assertable to T.
		req = reflect.New(rt.Elem()).Convert(rt).Interface().(T)
		target = req
	}
	err := unmarshal(body, target)
	return req, err
}

// Start opens the RabbitMQ connection and registers every configured consumer
// on it.
//
// Each consumer receives the server-wide options followed by its own Options,
// so its own options are applied last. In this go-rabbitmq API (handler passed
// to NewConsumer), NewConsumer returns once the consumer is set up and handles
// deliveries in background goroutines. Consumers are therefore created
// synchronously here, and any setup error is returned to the caller instead of
// panicking in a detached goroutine.
//
// On error the connection (if it was opened) is left for Stop to close.
func (s *ConsumerServer) Start(_ context.Context) error {
	var err error
	s.conn, err = rabbitmq.NewConn(
		s.Address,
		rabbitmq.WithConnectionOptionsLogging,
	)
	if err != nil {
		return fmt.Errorf("connect to rabbitmq: %w", err)
	}

	for _, c := range s.Consumers {
		// Build a fresh slice per consumer. Appending directly onto
		// s.globalConsumerOptions could write into its spare capacity and let
		// one consumer's options overwrite another's.
		opts := make([]func(*rabbitmq.ConsumerOptions), 0, len(s.globalConsumerOptions)+len(c.Options))
		opts = append(opts, s.globalConsumerOptions...)
		opts = append(opts, c.Options...)

		if _, err := rabbitmq.NewConsumer(s.conn, c.Handler, c.QueueName, opts...); err != nil {
			return fmt.Errorf("start consumer for queue %q: %w", c.QueueName, err)
		}
	}
	return nil
}

// Stop closes the connection opened by Start and returns any error from
// closing it.
//
// It is a no-op when Start never established a connection (for example
// because rabbitmq.NewConn failed and left conn nil). A host such as a kratos
// App may still call Stop after a failed Start, and dereferencing a nil conn
// would panic.
func (s *ConsumerServer) Stop(_ context.Context) error {
	if s.conn == nil {
		return nil
	}
	return s.conn.Close()
}
